{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# mGENRE for fairseq\n",
    "\n",
    "First make sure that you have [fairseq](https://github.com/pytorch/fairseq) installed.\n",
    "Since `fairseq` is going through breaking changes please install it from [this](https://github.com/nicola-decao/fairseq/tree/fixing_prefix_allowed_tokens_fn) fork using: \n",
    "```bash\n",
    "git clone --branch fixing_prefix_allowed_tokens_fn https://github.com/nicola-decao/fairseq\n",
    "cd fairseq\n",
    "pip install --editable ./\n",
    "``` \n",
    "as described in the [fairseq repository](https://github.com/pytorch/fairseq#requirements-and-installation) since `pip install fairseq` has issues. \n",
    "\n",
    "### Download\n",
    "* the pre-trained **model** [fairseq_multilingual_entity_disambiguation](https://dl.fbaipublicfiles.com/GENRE/fairseq_multilingual_entity_disambiguation.tar.gz);\n",
    "* the **prefix tree (trie)** from Wikipedia titles [titles_lang_all105_trie_with_redirect.pkl](http://dl.fbaipublicfiles.com/GENRE/titles_lang_all105_trie_with_redirect.pkl)---this is fast but memory inefficient prefix tree implemented with nested python `dict`. As an alternative, we have a prefix tree implemented with `marisa_trie` that is much more memory efficient but a little slower [titles_lang_all105_marisa_trie_with_redirect.pkl](http://dl.fbaipublicfiles.com/GENRE/titles_lang_all105_marisa_trie_with_redirect.pkl);\n",
    "* the dictionary to map the generated strings to **Wikidata identifiers** [lang_title2wikidataID-normalized_with_redirect](https://dl.fbaipublicfiles.com/GENRE/lang_title2wikidataID-normalized_with_redirect.pkl) (the inverse mapping is availabe here [wikidataID2lang_title-normalized_with_redirect](https://dl.fbaipublicfiles.com/GENRE/wikidataID2lang_title-normalized_with_redirect.pkl));\n",
    "* optionally, we can use a **mention table** to restrict the search space to a number of candidates [mention2wikidataID_with_titles_label_alias_redirect](https://dl.fbaipublicfiles.com/GENRE/mention2wikidataID_with_titles_label_alias_redirect.pkl).\n",
    "\n",
    "\n",
    "# mGENRE for transformers\n",
    "\n",
    "First make sure that you have [transformers](https://github.com/huggingface/transformers) >=4.2.0 installed. \n",
    "**NOTE: we used fairseq for all experiments in the paper. The huggingface/transformers models are obtained with a [conversion script](https://github.com/facebookresearch/GENRE/blob/main/scripts_genre/convert_bart_original_pytorch_checkpoint_to_pytorch.py).**\n",
    "\n",
    "Then load the trie and define the function to apply the constraints with the entities trie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIONAL:\n",
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from genre.trie import Trie, MarisaTrie\n",
    "\n",
    "with open(\"../data/lang_title2wikidataID-normalized_with_redirect.pkl\", \"rb\") as f:\n",
    "    lang_title2wikidataID = pickle.load(f)\n",
    "\n",
    "# fast but memory inefficient prefix tree (trie) -- it is implemented with nested python `dict`\n",
    "# NOTE: loading this map may take up to 10 minutes and occupy a lot of RAM!\n",
    "# with open(\"../data/titles_lang_all105_trie_with_redirect.pkl\", \"rb\") as f:\n",
    "#     trie = Trie.load_from_dict(pickle.load(f))\n",
    "\n",
    "# memory efficient but slower prefix tree (trie) -- it is implemented with `marisa_trie`\n",
    "with open(\"../data/titles_lang_all105_marisa_trie_with_redirect.pkl\", \"rb\") as f:\n",
    "    trie = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for pytorch/fairseq\n",
    "from genre.fairseq_model import mGENRE\n",
    "model = mGENRE.from_pretrained(\"../models/fairseq_multilingual_entity_disambiguation\").eval()\n",
    "\n",
    "# for huggingface/transformers\n",
    "# from genre.hf_model import mGENRE\n",
    "# model = mGENRE.from_pretrained(\"../models/hf_multilingual_entity_disambiguation\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and simply use `.sample` to make predictions constraining using `prefix_allowed_tokens_fn`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'text': 'Albert Einstein >> it', 'score': tensor(-0.0808)},\n",
       "  {'text': 'Albert Einstein (disambiguation) >> en', 'score': tensor(-1.0998)},\n",
       "  {'text': 'Alfred Einstein >> it', 'score': tensor(-1.4337)},\n",
       "  {'text': 'Alberto Einstein >> it', 'score': tensor(-1.4619)},\n",
       "  {'text': 'Einstein >> it', 'score': tensor(-1.5765)}]]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [\"[START] Einstein [END] era un fisico tedesco.\"]\n",
    "# Italian for \"[START] Einstein [END] was a German physicist.\"\n",
    "\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
    "        e for e in trie.get(sent.tolist())\n",
    "        if e < len(model.task.target_dictionary)\n",
    "        # for huggingface/transformers\n",
    "        # if e < len(model2.tokenizer) - 1\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, we can use the `lang_title2wikidataID` dictionary to map the generated strings to Wikidata identifiers via the function `text_to_id`. The boolean parameter `marginalise` enables the aggregation of scores by entity ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'id': 'Q937',\n",
       "   'texts': ['Albert Einstein >> it',\n",
       "    'Alberto Einstein >> it',\n",
       "    'Einstein >> it'],\n",
       "   'scores': tensor([-0.0808, -1.4619, -1.5765]),\n",
       "   'score': tensor(-0.0884)},\n",
       "  {'id': 'Q60197',\n",
       "   'texts': ['Alfred Einstein >> it'],\n",
       "   'scores': tensor([-1.4337]),\n",
       "   'score': tensor(-3.2058)},\n",
       "  {'id': 'Q15990626',\n",
       "   'texts': ['Albert Einstein (disambiguation) >> en'],\n",
       "   'scores': tensor([-1.0998]),\n",
       "   'score': tensor(-3.6478)}]]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
    "        e for e in trie.get(sent.tolist())\n",
    "        if e < len(model.task.target_dictionary)\n",
    "        # for huggingface/transformers\n",
    "        # if e < len(model2.tokenizer) - 1\n",
    "    ],\n",
    "    text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(\" >> \")))], key=lambda y: int(y[1:])),\n",
    "    marginalize=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar to `GENRE` we can use a mention table to restrict the search space to a number of candidates. We need fist two addinional dictionaries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mapping between mentions and Wikidata IDs and number of times they appear on Wikipedia\n",
    "with open(\"../data/mention2wikidataID_with_titles_label_alias_redirect.pkl\", \"rb\") as f:\n",
    "    mention2wikidataID = pickle.load(f)\n",
    "    \n",
    "# mapping between wikidataIDs and (lang, title) in all languages\n",
    "with open(\"../data/wikidataID2lang_title-normalized_with_redirect.pkl\", \"rb\") as f:\n",
    "    wikidataID2lang_title = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "then let's build the temporary trie for the mention and run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'id': 'Q937',\n",
       "   'texts': ['Albert Einstein >> it',\n",
       "    'Alberto Einstein >> it',\n",
       "    'Einstein >> it'],\n",
       "   'scores': tensor([-0.0808, -1.4619, -1.5765]),\n",
       "   'score': tensor(-0.0884)},\n",
       "  {'id': 'Q60197',\n",
       "   'texts': ['Alfred Einstein >> it'],\n",
       "   'scores': tensor([-1.4337]),\n",
       "   'score': tensor(-3.2058)},\n",
       "  {'id': 'Q13426745',\n",
       "   'texts': ['Albert Einstein (album) >> it'],\n",
       "   'scores': tensor([-2.0844]),\n",
       "   'score': tensor(-5.8956)}]]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [\"[START] Einstein [END] era un fisico tedesco.\"]\n",
    "# Italian for \"[START] Einstein [END] was a German physicist.\"\n",
    "\n",
    "# building a temporary trie for the mention (to the purpose of\n",
    "# demonstraing the use of the mention table, let's restrict the\n",
    "# prediction to only candidates in Italian!)\n",
    "trie_of_mention = Trie([\n",
    "    [2] + model.encode(f\"{name} >> {lang}\")[1:].tolist()\n",
    "    for qid in mention2wikidataID[\"Einstein\"]\n",
    "    for lang, name in wikidataID2lang_title.get(qid, [])\n",
    "    if lang == \"it\"\n",
    "])\n",
    "\n",
    "# getting predictions\n",
    "model.sample(\n",
    "    sentences,\n",
    "    prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
    "        e for e in trie_of_mention.get(sent.tolist())\n",
    "        if e < len(model.task.target_dictionary)\n",
    "        # for huggingface/transformers\n",
    "        # if e < len(model2.tokenizer) - 1\n",
    "    ],\n",
    "    text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(\" >> \")))], key=lambda y: int(y[1:])),\n",
    "    marginalize=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
